# -*- coding: utf-8 -*-
"""


@author: Xuezhi Cang xuezhicang@gmail.com
## Create optimal channel networks (OCNs) by simulated annealing.
## The code is intended to run on HPC nodes.
"""




import os
import shutil
import numpy as np
import copy
from collections import defaultdict
import sys
import multiprocessing 
import time
from numpy.random import RandomState
from scoop import futures

## Coordinate conventions
# i,j: row and column of node (i,j) in the watershed grid
# ID:  flattened node ID, i * SIZE_X + j (row-major)
# idx: index of the node in the node_info_container arrays

def cal_energy(node_info_container,energy_exp):
    """Return the energy of the network.

    The energy is the sum over all nodes of the flow accumulation raised to
    the power ``energy_exp``.

    Parameters
    ----------
    node_info_container : mapping
        Must provide the flow accumulation array under the key ``'facc'``.
    energy_exp : number
        Exponent applied to every flow accumulation value.

    Returns
    -------
    The summed value, as produced by ``np.sum``.
    """
    facc_values = node_info_container['facc']
    weighted_facc = facc_values ** energy_exp
    return np.sum(weighted_facc)


def facc_cal(node_info_container,temp_ID,real_outlet_idx):
    """Compute the flow accumulation of ``temp_ID`` and of every node upstream of it.

    The flow accumulation of a node is 1 (the node itself) plus the flow
    accumulation of all nodes listed in its row of ``'sources'``. The values
    are written in place into ``node_info_container['facc']``.

    The tree is walked with an explicit stack instead of recursion, so long
    flow paths cannot hit Python's recursion limit, and node IDs are resolved
    through a dictionary instead of one array scan per node.

    Parameters
    ----------
    node_info_container : mapping
        Must provide ``'node_id'`` (1-D array of node IDs), ``'sources'``
        (one row per node holding the IDs of the nodes draining into it,
        negative entries meaning "no source") and ``'facc'`` (1-D array
        that receives the result).
    temp_ID : int
        ID of the node at which the accumulation starts.
    real_outlet_idx : int
        Array index of the outlet node.

    Returns
    -------
    ``node_info_container`` when ``temp_ID`` is the outlet (including an
    outlet without any source), otherwise the flow accumulation of
    ``temp_ID`` as an ``int``.

    Raises
    ------
    IndexError
        If ``temp_ID`` or one of the listed source IDs is not in ``'node_id'``.
    """
    node_ids = node_info_container['node_id']
    sources = node_info_container['sources']
    facc_values = node_info_container['facc']

    # Same first-match semantics as np.where(...)[0][0], built once.
    id_to_idx = {}
    for idx, node_id in enumerate(node_ids):
        id_to_idx.setdefault(int(node_id), idx)

    root_idx = np.where(node_ids == temp_ID)[0][0]

    # Pre-order walk of the upstream tree; every visited node starts with an
    # accumulation of 1 (itself) and remembers the node it drains into.
    visit_order = []
    pending = [(root_idx, -1)]
    while pending:
        cur_idx, parent_idx = pending.pop()
        visit_order.append((cur_idx, parent_idx))
        facc_values[cur_idx] = 1
        source_row = sources[cur_idx]
        for source_id in source_row[source_row > -1]:
            try:
                child_idx = id_to_idx[int(source_id)]
            except KeyError:
                raise IndexError(
                    "source node %d is not listed in 'node_id'" % source_id
                ) from None
            pending.append((child_idx, cur_idx))

    # In reversed pre-order every node is complete before it is added to the
    # node it drains into.
    for cur_idx, parent_idx in reversed(visit_order):
        if parent_idx >= 0:
            facc_values[parent_idx] += facc_values[cur_idx]

    if root_idx != real_outlet_idx:
        return int(facc_values[root_idx])
    return node_info_container


def facc(node_info_container,temp_ID):
    """Recompute the flow accumulation of the network draining to ``temp_ID``.

    ``temp_ID`` is treated as the outlet: its array index is looked up in
    ``node_info_container['node_id']`` and handed to ``facc_cal``, which
    writes the flow accumulation of the outlet and of everything upstream of
    it into ``node_info_container['facc']``.

    Parameters
    ----------
    node_info_container : mapping
        Node table with at least ``'node_id'``, ``'sources'`` and ``'facc'``.
    temp_ID : int
        ID of the outlet node.

    Returns
    -------
    The same ``node_info_container`` object, updated in place.

    Raises
    ------
    IndexError
        If ``temp_ID`` is not in ``'node_id'``.
    """
    real_outlet_idx = np.where(node_info_container['node_id'] == temp_ID)[0][0]

    # facc_cal updates 'facc' in place. Its return value is not guaranteed to
    # be the container (it can be a plain node count), so the container is
    # returned directly instead of whatever facc_cal hands back.
    facc_cal(node_info_container,temp_ID,real_outlet_idx)
    return node_info_container


def update_flow_direction(node_info_container,NUM_DIM,SIZE_Y,SIZE_X):
    """Rebuild the ``'sources'`` table from the current ``'sink'`` array.

    For every node that has a sink (``'sink'`` entry different from -1) the
    node's ID is written into the row of its sink, in the column that
    corresponds to the position of the node relative to the sink:

        0 1 2
        3 C 4
        5 6 7

    Parameters
    ----------
    node_info_container : mapping
        Must provide ``'node_id'`` (flattened row-major node IDs) and
        ``'sink'`` (ID of the node each node drains into, -1 for none).
        Its ``'sources'`` entry is replaced by a freshly built array.
    NUM_DIM : int
        Number of grid dimensions; the table gets ``3 ** NUM_DIM - 1`` columns.
    SIZE_Y, SIZE_X : int
        Grid shape used to turn node IDs back into (row, column).

    Returns
    -------
    The same ``node_info_container`` object.

    Raises
    ------
    IndexError
        If a sink ID is not present in ``'node_id'``.
    ValueError
        If a node and its sink are not neighbouring grid cells.
    """
    MOVE_STEP = [[-1,-1],[-1,0],[-1,1],[0,-1],[0,1],[1,-1],[1,0],[1,1]]
    node_ids = node_info_container['node_id']
    sinks = node_info_container['sink']

    # np.int was removed in NumPy 1.24; the builtin int selects the same dtype.
    sources = np.full((len(node_ids), 3 ** NUM_DIM - 1), -1, dtype=int)
    node_info_container['sources'] = sources

    # One dictionary instead of one np.where scan per node (first match wins,
    # like np.where(...)[0][0]).
    id_to_idx = {}
    for idx, node_id in enumerate(node_ids):
        id_to_idx.setdefault(int(node_id), idx)

    for id_idx, id_content in enumerate(node_ids):
        sink_id = sinks[id_idx]
        if sink_id == -1:
            continue  # node without a receiver (the outlet)

        try:
            sink_idx = id_to_idx[int(sink_id)]
        except KeyError:
            raise IndexError(
                "sink node %d is not listed in 'node_id'" % sink_id
            ) from None

        source_i, source_j = np.unravel_index(id_content, (SIZE_Y, SIZE_X))
        sink_i, sink_j = np.unravel_index(sink_id, (SIZE_Y, SIZE_X))

        # offset of the source as seen from the sink selects the column
        neighbor_idx = MOVE_STEP.index([int(source_i - sink_i),
                                        int(source_j - sink_j)])
        sources[sink_idx][neighbor_idx] = id_content

    return node_info_container
        



def initialize_OCN(node_info_container,OUTLET_ID,NUM_DIM,SIZE_Y,SIZE_X):
    """Create the initial flow directions by growing the network from the outlet.

    Starting at the outlet, nodes are taken from a first-in-first-out queue.
    Every rook (edge-sharing) neighbour of the current node that has no sink
    yet is made to drain into the current node and is queued. When the queue
    runs empty although nodes are still unassigned, the already processed
    nodes are searched for a bishop (diagonal) neighbour that has not been
    processed; such neighbours drain into that processed node and the
    expansion continues from them.

    Afterwards ``'sources'`` is rebuilt with ``update_flow_direction`` and
    ``'facc'`` is computed with ``facc``.

    Parameters
    ----------
    node_info_container : mapping
        Must provide ``'node_id'``, ``'neighbours'`` (one row of 8 neighbour
        IDs per node, -1 where there is no neighbour), ``'sink'`` (all -1 on
        entry) and ``'facc'``. ``'sink'`` is modified in place.
    OUTLET_ID : int
        ID of the outlet node.
    NUM_DIM, SIZE_Y, SIZE_X : int
        Passed through to ``update_flow_direction``.

    Returns
    -------
    The same ``node_info_container`` object.

    Raises
    ------
    ValueError
        If some nodes can be reached from the outlet neither through rook nor
        through bishop neighbours (the original code failed with an
        ``IndexError`` in this situation).
    """
    from collections import deque  # only needed here; O(1) pops from the left

    # Column layout of a 'neighbours' row around the centre cell C:
    # 0 1 2
    # 3 C 4
    # 5 6 7
    ROOK_NEIGHBORS_ID = (1, 3, 4, 6)
    BISHOP_NEIGHBORS_ID = (0, 2, 5, 7)

    node_ids = node_info_container['node_id']
    neighbours = node_info_container['neighbours']
    sinks = node_info_container['sink']
    n_nodes = len(node_ids)

    # node ID -> array index, built once instead of one np.where per lookup
    id_to_idx = {}
    for idx, node_id in enumerate(node_ids):
        id_to_idx.setdefault(int(node_id), idx)

    boundary = deque([OUTLET_ID])
    processed_nodes = set()  # nodes whose neighbours have been examined
    while boundary:
        cur_node_ID = boundary.popleft()
        processed_nodes.add(cur_node_ID)
        cur_node_idx = id_to_idx[int(cur_node_ID)]

        # let the unassigned rook neighbours drain into the current node
        for i in ROOK_NEIGHBORS_ID:
            cur_neighbour_id = neighbours[cur_node_idx][i]
            if cur_neighbour_id < 0 or cur_neighbour_id == OUTLET_ID:
                continue
            cur_neighbour_idx = id_to_idx[int(cur_neighbour_id)]
            if sinks[cur_neighbour_idx] == -1:
                sinks[cur_neighbour_idx] = cur_node_ID
                boundary.append(cur_neighbour_id)

        # exception: some nodes are connected to the rest only diagonally
        if not boundary and len(processed_nodes) < n_nodes:
            for processed_id in list(processed_nodes):
                processed_idx = id_to_idx[int(processed_id)]
                for i_bishop_neighbor in BISHOP_NEIGHBORS_ID:
                    cur_bishop_neighbor = neighbours[processed_idx][i_bishop_neighbor]
                    if (cur_bishop_neighbor != -1
                            and cur_bishop_neighbor not in processed_nodes):
                        boundary.append(cur_bishop_neighbor)
                        sinks[id_to_idx[int(cur_bishop_neighbor)]] = processed_id
                if boundary:
                    break  # continue the rook expansion from the new nodes
            else:
                raise ValueError(
                    "%d node(s) cannot be reached from outlet %d through rook "
                    "or bishop neighbours"
                    % (n_nodes - len(processed_nodes), OUTLET_ID)
                )

    # update the sources
    node_info_container = update_flow_direction(node_info_container,NUM_DIM,SIZE_Y,SIZE_X)

    # update the flow accumulation; facc works in place on the container
    facc(node_info_container,OUTLET_ID)

    return node_info_container
    

def OCN_creation(variables):
    """Create one optimal channel network (OCN) for a watershed by simulated annealing.

    The watershed matrix is read from a CSV file, an initial network is built
    with ``initialize_OCN`` and then, for ``40 * SIZE_Y * SIZE_X`` iterations,
    a random node is redirected to a random neighbour that is neither its
    current sink nor one of its sources. A redirection is discarded when it
    crosses an existing diagonal link or closes a loop; otherwise it is
    accepted if the energy (see ``cal_energy``) does not increase, or with
    probability ``exp((old - new) / temperature)`` if it does.

    Parameters
    ----------
    variables : sequence
        ``[i_th, filename, folder_location, energy_exp]``:

        * ``i_th`` - repetition index; also the seed of the random generator
          and part of the output file names.
        * ``filename`` - CSV file with the watershed matrix. Cells > 0 belong
          to the watershed, the cell equal to 2 is the outlet (the last one in
          row-major order if there are several). The text between the first
          and second ``_`` of the part before the first ``.`` is used as the
          watershed tag in the output file names.
        * ``folder_location`` - output folder; created if it does not exist.
        * ``energy_exp`` - exponent of the energy function.

    Returns
    -------
    The node table (``defaultdict``) with the keys ``'node_id'``,
    ``'neighbours'``, ``'sink'``, ``'sources'`` and ``'facc'``.

    Side effects
    ------------
    Writes ``nodeID_``, ``facc_``, ``energy_`` and ``sink_`` CSV files into
    ``folder_location`` and prints the elapsed time.

    Raises
    ------
    ValueError
        If the matrix contains no outlet cell (value 2).

    Notes
    -----
    All random numbers, including the acceptance draw, come from the generator
    seeded with ``i_th``; a repetition is therefore reproducible.
    """
    i_th = variables[0]
    filename = variables[1]
    folder_location = variables[2]
    energy_exp = variables[3]

    # create the folder if it does not exist
    os.makedirs(folder_location, exist_ok=True)

    local_randomState = RandomState(int(i_th))

    starttime = time.time()

    ##############################
    # set the parameters
    ##############################
    NUM_DIM = 2
    # offsets of the 8 neighbours, in the column order of 'neighbours'/'sources'
    MOVE_STEP = [[-1,-1],[-1,0],[-1,1],[0,-1],[0,1],[1,-1],[1,0],[1,1]]
    INITIAL_COLLING_PHASE = 0  # this value should be smaller than 0.3
    COOLING_RATE = 1  # this value should be between 0.5 and 10
    ITERATIONS_PER_CELL = 40

    my_ws = np.genfromtxt(filename, delimiter=',').astype(int)
    SIZE_Y = my_ws.shape[0]
    SIZE_X = my_ws.shape[1]
    DOMAIN_SIZE = (SIZE_Y, SIZE_X)

    # watershed tag used in every output file name
    ws_tag = filename.split(".")[0].split("_")[1]

    def output_path(prefix):
        return folder_location + "/" + prefix + ws_tag + "_" + str(i_th) + ".csv"

    # the outlet is the cell marked with 2; like a full row-major scan that
    # keeps overwriting its result, the last such cell wins
    outlet_cells = np.argwhere(my_ws == 2)
    if len(outlet_cells) == 0:
        raise ValueError("no outlet cell (value 2) found in " + str(filename))
    OUTLET_Y, OUTLET_X = outlet_cells[-1]
    OUTLET_ID = np.ravel_multi_index((OUTLET_Y, OUTLET_X), DOMAIN_SIZE)

    # region mask: the watershed consists of the cells with a value > 0
    # TODO: input the real research area
    REGION_MASK = my_ws > 0

    ##############################
    # initialize the OCN data structure
    ##############################
    # node_info_container is the container for all the nodes' info
    node_info_container = defaultdict()

    # node IDs: flattened (row-major) indices of the valid cells
    node_ids = np.flatnonzero(REGION_MASK)
    n_nodes = len(node_ids)
    node_info_container['node_id'] = node_ids

    # the node IDs never change, so one lookup table serves the whole run
    id_to_idx = {int(node_id): idx for idx, node_id in enumerate(node_ids)}

    # neighbours' IDs of each node; -1 where the neighbour does not exist.
    # (np.int was removed in NumPy 1.24; builtin int selects the same dtype.)
    neighbours = np.full((n_nodes, 3 ** NUM_DIM - 1), -1, dtype=int)
    for i_idx, i_content in enumerate(node_ids):
        center_i, center_j = np.unravel_index(i_content, DOMAIN_SIZE)
        for j_idx, (step_i, step_j) in enumerate(MOVE_STEP):
            new_i = center_i + step_i
            new_j = center_j + step_j
            if (0 <= new_i < SIZE_Y and 0 <= new_j < SIZE_X
                    and REGION_MASK[new_i][new_j]):
                neighbours[i_idx][j_idx] = np.ravel_multi_index((new_i, new_j), DOMAIN_SIZE)
    node_info_container['neighbours'] = neighbours

    # receiver (sink) and donors (sources) of each node, flow accumulation
    node_info_container['sink'] = np.full(n_nodes, -1, dtype=np.int64)
    node_info_container['sources'] = np.full((n_nodes, 3 ** NUM_DIM - 1), -1, dtype=int)
    node_info_container['facc'] = np.ones(n_nodes, dtype=np.int64)

    # initialize the OCN
    node_info_container = initialize_OCN(node_info_container,OUTLET_ID,NUM_DIM,SIZE_Y,SIZE_X)

    ###############################
    # energy and cooling schedule
    ###############################
    cur_energy = cal_energy(node_info_container,energy_exp)

    n_Iteration = ITERATIONS_PER_CELL * SIZE_Y * SIZE_X
    Energy_arr = np.full(n_Iteration, -1.0)

    # constant temperature during the initial phase, exponential decay after
    # it; the small offset keeps the temperature strictly positive
    steps_after_plateau = np.arange(n_Iteration) - INITIAL_COLLING_PHASE * n_Iteration
    cooling_steps = np.maximum(steps_after_plateau, 0)
    Temperature_arr = cur_energy * np.exp(-COOLING_RATE * cooling_steps / n_nodes)
    Temperature_arr = Temperature_arr + 0.000000000000001

    # store the node IDs
    np.savetxt(output_path("nodeID_"), node_info_container['node_id'], delimiter=",")

    ###############################
    # simulated annealing
    ###############################
    for i in range(n_Iteration):

        Energy_arr[i] = cur_energy

        # select the node; the outlet keeps its (missing) sink
        random_sel_ID = local_randomState.choice(node_info_container['node_id'], 1)[0]
        if(random_sel_ID == OUTLET_ID):
            continue
        random_sel_idx = id_to_idx[int(random_sel_ID)]

        sinks = node_info_container['sink']
        old_sink_id = sinks[random_sel_idx]
        source_row = node_info_container['sources'][random_sel_idx]
        source_ids = source_row[source_row > -1]

        # neighbours that are neither the sink nor a source of the node
        irrelative_neis = [
            nei_id for nei_id in node_info_container['neighbours'][random_sel_idx]
            if nei_id != -1 and nei_id != old_sink_id and nei_id not in source_ids
        ]
        if not irrelative_neis:
            # nothing to redirect to: the network and its energy stay as they are
            continue
        new_sink_id = local_randomState.choice(irrelative_neis, 1)[0]

        # The two rejection tests below only read the sinks of other nodes, so
        # they are run before the network is modified.

        # cross-flow: a diagonal link must not cross a link between the two
        # cells that share an edge with both of its end points
        source_i, source_j = np.unravel_index(random_sel_ID, DOMAIN_SIZE)
        sink_i, sink_j = np.unravel_index(new_sink_id, DOMAIN_SIZE)
        if abs(source_i - sink_i) == 1 and abs(source_j - sink_j) == 1:
            corner_1_ID = int(np.ravel_multi_index((source_i, sink_j), DOMAIN_SIZE))
            corner_2_ID = int(np.ravel_multi_index((sink_i, source_j), DOMAIN_SIZE))
            corner_1_idx = id_to_idx.get(corner_1_ID)
            corner_2_idx = id_to_idx.get(corner_2_ID)
            if corner_1_idx is not None and corner_2_idx is not None:
                if (sinks[corner_1_idx] == corner_2_ID
                        or sinks[corner_2_idx] == corner_1_ID):
                    continue

        # loop: the selected node must not lie downstream of its new sink
        creates_loop = False
        looping_test_node = new_sink_id
        while looping_test_node != OUTLET_ID:
            if looping_test_node == random_sel_ID:
                creates_loop = True
                break
            looping_test_node = sinks[id_to_idx[int(looping_test_node)]]
        if creates_loop:
            continue

        # keep the current state so that a rejected change can be undone
        node_info_container_keep = copy.deepcopy(node_info_container)

        # apply the change and update the flow direction
        sinks[random_sel_idx] = new_sink_id
        node_info_container = update_flow_direction(node_info_container,NUM_DIM,SIZE_Y,SIZE_X)

        # update the flow accumulation: the accumulation of the redirected
        # node moves from the old downstream path to the new one. The outlet
        # is on both paths, so its value does not change.
        facc_values = node_info_container['facc']
        moved_facc = facc_values[random_sel_idx]

        temp_next_ID = new_sink_id
        while(temp_next_ID != OUTLET_ID):
            temp_next_idx = id_to_idx[int(temp_next_ID)]
            facc_values[temp_next_idx] = facc_values[temp_next_idx] + moved_facc
            temp_next_ID = node_info_container['sink'][temp_next_idx]

        temp_next_ID = old_sink_id
        while(temp_next_ID != OUTLET_ID):
            temp_next_idx = id_to_idx[int(temp_next_ID)]
            facc_values[temp_next_idx] = facc_values[temp_next_idx] - moved_facc
            temp_next_ID = node_info_container_keep['sink'][temp_next_idx]

        # check whether the new pattern is accepted
        new_energy = cal_energy(node_info_container,energy_exp)
        if ((new_energy <= cur_energy) or
                (local_randomState.uniform(0,1,1)[0]
                 < np.exp((cur_energy - new_energy)/Temperature_arr[i]))):
            cur_energy = new_energy
        else:
            # the snapshot is not needed any more, so it becomes the current
            # state again without a second copy
            node_info_container = node_info_container_keep

    np.savetxt(output_path("facc_"), node_info_container['facc'], delimiter=",")
    np.savetxt(output_path("energy_"), np.array(Energy_arr), delimiter=",")
    np.savetxt(output_path("sink_"), node_info_container['sink'], delimiter=",")

    print('the ' + str(ws_tag) +' ' + str(i_th) +' took {} seconds'.format(time.time() - starttime))

    return node_info_container



  

        
if __name__ == '__main__':

    # Parallel processing uses the SCOOP library.
    # Command line arguments:
    #   sys.argv[1]  prefix of the result folders; the results of watershed
    #                <tag> go to sys.argv[1] + "ws_" + <tag>, so a trailing
    #                path separator is needed to get sub-folders
    #   sys.argv[2]  number of repetitions per watershed
    # The watershed matrices are read from the current working directory.

    energy_exp = 0.7
    # NOTE: the thresholds are not used here; the size of the watershed
    # matrices is checked in the preprocessing program.
    too_large_ws_threshold = 2500
    too_small_ws_threshold = 100

    location_folder = sys.argv[1] #"/data1/data2_new/marssim/z1747635/OCN/test/"
    location_result = location_folder + "ws_"
    repeat_time = int(sys.argv[2])

    # create two folders:
    #   one for the watersheds too large; one for the watersheds too small
    toolarge_w_folder = "toolarge_ws"
    os.makedirs(toolarge_w_folder, exist_ok=True)
    toosmall_ws_folder = "toosmall_ws"
    os.makedirs(toosmall_ws_folder, exist_ok=True)

    starttime = time.time()

    # Select all the watershed matrix files of the working directory.
    # Only regular files are taken, so that folders starting with "ws"
    # (e.g. result folders of an earlier run) are not mistaken for input, and
    # the name must contain the "_" that separates the watershed tag.
    files_in_folder = os.listdir()
    ws_files = [
        i_ws_file for i_ws_file in files_in_folder
        if i_ws_file.startswith("ws")
        and os.path.isfile(i_ws_file)
        and "_" in i_ws_file.split(".")[0]
    ]

    # THE SIZE OF WSMAT IS CHECKED IN THE PREPROCESSING PROGRAM
    ws_files_propersize = list(ws_files)

    # one task per watershed and repetition
    input_variables = []
    for ith_ws_file in ws_files_propersize:
        folder_location = location_result + ith_ws_file.split(".")[0].split("_")[1]
        for ith_repeat_time in range(repeat_time):
            input_variables.append([ith_repeat_time,ith_ws_file,folder_location,energy_exp])

    returnValues = list(futures.map(OCN_creation, input_variables))
    print('That total took {} seconds'.format(time.time() - starttime))
